### SPACEWHALE PROJECT
### S1 File
### File: gen_training_patches.py
### Alex Borowicz, Grant Humphries
### contact: aborowicz@coa.edu
### Find all code at https://github.com/aborowicz/spacewhale/
### This script takes in a large image, like a satellite image
### or an aerial survey image, and chops it into smaller tiles.
### These tiles can overlap (by making --step smaller than --size)
### or can line up with no overlap (by making --step and --size the same)
###########################################################################
### Usage:
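# --root is the dir where the image(s) are located; it is searched recursively and only .png files are tiled (default: ./Water_Training)
# --step is the step between tiles: a new tile starts every --step pixels (default: 500)
# --size is the size of the tile in px (always square) (default: 30)
# --output is the dir to write the patches in (default: ./water)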
###########################################################################
# Example (making non-overlapping 32 x 32 pixel tiles):
# python gen_training_patches.py --root ./imagery/hawaii/ --step 32 --size 32 --output ./tiled_hawaii
###########################################################################

import numpy as np
from PIL import Image
import time
import torch
import os.path
import argparse
from scipy import misc
from m_util import *

### Instantiate the spacewhale helper class (made available by the star import from m_util)
s = spacewhale()

### Command-line arguments. Each option carries help text so that
### `python gen_training_patches.py --help` describes the interface.
parse = argparse.ArgumentParser()
parse.add_argument('--root', type=str, default='./Water_Training',
                   help='directory containing the source image(s) (default: %(default)s)')
parse.add_argument('--step', type=int, default=500,
                   help='step between tiles in pixels (default: %(default)s)')
parse.add_argument('--size', type=int, default=30,
                   help='side length of each square tile in pixels (default: %(default)s)')
parse.add_argument('--output', type=str, default='./water',
                   help='directory the patches are written to (default: %(default)s)')
opt = parse.parse_args()

# Aliases for the parsed paths, stored on the same namespace object.
opt.im_fold = opt.root
opt.results = opt.output

# Hand the output path to the m_util helper (presumably creates the
# directory if needed -- confirm in m_util.sdmkdir).
s.sdmkdir(opt.results)

# Channel count kept on the options object; it is not read again in this script.
opt.input_nc = 3

# Filled by the image-listing loop: (full path, file name) pairs, and the file names alone.
imlist = []
imnamelist = []

### Walk --root recursively and record every PNG (extension matched case-insensitively).
### imlist receives (full path, file name) pairs; imnamelist receives the file names only.
for folder, _, files in sorted(os.walk(opt.root)):
    pngs = [name for name in files if name.lower().endswith('.png')]
    imlist.extend((os.path.join(folder, name), name) for name in pngs)
    imnamelist.extend(pngs)

            
### Go through those images, call savepatch_train to tile them (from m_util)
for im_path, imname in imlist:
    # scipy.misc.imread was removed in SciPy 1.2; read the file with Pillow
    # instead. Converting to 'RGB' before building the array reproduces
    # what imread(..., mode='RGB') returned, and the context manager
    # closes the file handle once the pixels are copied.
    with Image.open(im_path) as img:
        png = np.array(img.convert('RGB'))

    # NOTE(review): a numpy image array has shape (rows, cols, channels), so
    # `w` is the row count and `h` the column count. They are passed in the
    # same order as before; confirm what savepatch_train expects.
    w, h, z = png.shape

    # Output prefix: <output dir>/<file name minus the 4-char '.png'>#
    s.savepatch_train(png, w, h, opt.step, opt.size, opt.results + '/' + imname[:-4] + '#')
